# -*- coding: utf-8 -*-
'''
Created on 2017年4月19日

@author: ZhuJiahui506
'''
import os
import time
import numpy as np
import random
from file_utils.file_reader import read_to_1d_list
from topic_utils.distribution_util import get_topic_proportion
from file_utils.file_writer import quick_write_1d_to_text


def get_selected_embeddings(read_filename, selected_word_list, embedding_dis):
    '''
    Load the embedding vectors for a given word list from a Chinese
    word-embedding file (one word followed by its vector per line).
    Words absent from the file keep an all-zero row.
    :param read_filename: path of the word-embedding file
    :param selected_word_list: words whose vectors should be extracted
    :param embedding_dis: embedding dimensionality (columns of the result)
    :return: numpy 2d array of shape (len(selected_word_list), embedding_dis),
             row i holding the vector of selected_word_list[i]
    '''
    
    selected_embeddings = np.zeros((len(selected_word_list), embedding_dis), dtype=float)
    
    # Map word -> row index once, so each file line costs O(1) instead of the
    # O(n) list.index() scan of the original. setdefault keeps the FIRST
    # occurrence for duplicate words, matching list.index semantics.
    word_to_row = {}
    for row, word in enumerate(selected_word_list):
        word_to_row.setdefault(word, row)
    
    with open(read_filename, 'r', encoding='utf-8') as f:
        for each_line in f:
            split_line = each_line.strip().split()
            if not split_line:
                continue  # skip blank lines
            row = word_to_row.get(split_line[0])
            if row is None:
                continue  # word not requested
            try:
                selected_embeddings[row, :] = [float(x) for x in split_line[1:]]
            except ValueError:
                # Malformed line (non-numeric token or wrong dimension):
                # leave the zero row, as the original silently did.
                pass
    
    return selected_embeddings


def generate_topic_triples(read_directory1, read_directory2, read_directory3, read_filename, write_directory):
    '''
    Link topics across consecutive time slices and emit <topic, relation, topic>
    triples (for TransE-style training data, judging by the output names).
    Relations come from cosine similarity between topic embedding vectors;
    each topic vector is the probability-weighted sum of its words' embeddings
    (topic-word distribution matrix dotted with the word-embedding matrix).
    :param read_directory1: directory of document-topic distribution files, one "<slice>.txt" per slice
    :param read_directory2: directory of topic-word distribution files, one "<slice>.txt" per slice
    :param read_directory3: directory of word-list files, one "<slice>.txt" per slice
    :param read_filename: Chinese word-embedding file (word + vector per line)
    :param write_directory: output directory for topic/relation lists and triple files
    '''
    
    embedding_dis = 200  # word (and hence topic) embedding dimensionality
    start_batch = 183  # first time-slice index to process
    
    t_gamma = 0.8  # cosine threshold for the "continuation" relation
    t_delta = 0.5  # lower cosine threshold for "merge"/"split" candidates
    t_eps = 0.5  # tolerance on the latter/former topic-support ratio

    # Load the starting slice as the "former" slice of the first pair.
    former_doc_dis = np.loadtxt(read_directory1 + '/' + str(start_batch) + '.txt')
    former_topic_dis = np.loadtxt(read_directory2 + '/' + str(start_batch) + '.txt')
    former_word_list = [each.split()[0] for each in read_to_1d_list(read_directory3 + '/' + str(start_batch) + '.txt')]
    
    _, former_topic_support = get_topic_proportion(former_doc_dis)  # support of each former-slice topic (frequency across documents)
    former_word_embeddings = get_selected_embeddings(read_filename, former_word_list, embedding_dis)
    # (topics x words) . (words x dim) -> one embedding row per topic
    former_topic_embeddings = np.dot(former_topic_dis, former_word_embeddings)
    
    former_topic_num = former_topic_dis.shape[0]
    former_norm = np.linalg.norm(former_topic_embeddings, 2, 1)  # row-wise L2 norms
    
    topic_relation_triples = []  # "<head> <relation> <tail>" strings (1d str list)
    topic_list = []  # all distinct topic identifiers seen in any triple (1d str list)
    relation_list = ["延续", "合并", "分裂"]  # relations: continuation, merge, split

    # Slide over slice pairs (i-1, i); 208 appears to be one past the last slice index.
    for i in range(start_batch + 1, 208):
        
        latter_doc_dis = np.loadtxt(read_directory1 + '/' + str(i) + '.txt')
        latter_topics_dis = np.loadtxt(read_directory2 + '/' + str(i) + '.txt')
        latter_word_list = [each.split()[0] for each in read_to_1d_list(read_directory3 + '/' + str(i) + '.txt')]

        _, latter_topic_support = get_topic_proportion(latter_doc_dis)  # support of each latter-slice topic (original comment said "former" — copy-paste slip)
        latter_word_embeddings = get_selected_embeddings(read_filename, latter_word_list, embedding_dis)
        latter_topic_embeddings = np.dot(latter_topics_dis, latter_word_embeddings)

        latter_topic_num = latter_topics_dis.shape[0]
        latter_norm = np.linalg.norm(latter_topic_embeddings, 2, 1)  # row-wise L2 norms
        
        # Cosine-similarity matrix between former and latter topics.
        # NOTE(review): a zero-norm topic embedding would yield nan/inf here — confirm inputs rule that out.
        cos_sim_matrix = np.zeros([former_topic_num, latter_topic_num], dtype=float)
        for j in range(former_topic_num):
            for k in range(latter_topic_num):
                cos_sim_matrix[j, k] = np.true_divide(
                    np.dot(former_topic_embeddings[j], latter_topic_embeddings[k]), (former_norm[j] * latter_norm[k])
                    )

        # Continuation: high cosine similarity AND the topic's support did not collapse.
        for j in range(former_topic_num):
            for k in range(latter_topic_num):
                topic_division = 0.0  # latter/former support ratio; stays 0 (no match) if former support is ~0
                if (former_topic_support[j] > 0.0001):
                    topic_division = np.true_divide(latter_topic_support[k], former_topic_support[j])
                if (cos_sim_matrix[j, k] > t_gamma) and (topic_division > (1 - t_eps)):
                    # continuation relation found
                    head_topic = "topic_" + str(i - 1) + "_" + str(j)
                    tail_topic = "topic_" + str(i) + "_" + str(k)
                    topic_relation_triples.append(" ".join([head_topic, relation_list[0], tail_topic]))
                    
                    if head_topic not in topic_list:
                        topic_list.append(head_topic)
                    if tail_topic not in topic_list:
                        topic_list.append(tail_topic)
        
        # Merge: several former topics each moderately similar (t_delta < cos <= t_gamma)
        # to one latter topic, and their support-weighted mean embedding is highly similar to it.
        for k in range(latter_topic_num):
            merge_index = []  # indices of former-slice topics that are merge candidates
            for j in range(former_topic_num):
                if (cos_sim_matrix[j, k] > t_delta) and (cos_sim_matrix[j, k] <= t_gamma):
                    merge_index.append(j)
            
            if len(merge_index) >= 2:
                # Support-weighted mean embedding of the candidate former topics.
                weight_numerator = np.zeros(embedding_dis)
                weight_denominator = 0.0
                for each in merge_index:
                    weight_numerator += (former_topic_support[each] * former_topic_embeddings[each])
                    weight_denominator += former_topic_support[each]
                weight_topic_embedding = np.true_divide(weight_numerator, weight_denominator)
                
                # Cosine similarity of the weighted embedding with the latter topic.
                # The scalar latter_norm[k] sits INSIDE norm() here; since norm(s*v) == s*norm(v)
                # for s >= 0 this still equals norm(v) * latter_norm[k], so the value is correct.
                merged_sim = np.true_divide(np.dot(weight_topic_embedding, latter_topic_embeddings[k]), 
                                            (np.linalg.norm(weight_topic_embedding * latter_norm[k])))
                
                if (merged_sim > t_gamma):
                    # merge relation found: one triple per merging former topic
                    for each in merge_index:
                        head_topic = "topic_" + str(i - 1) + "_" + str(each)
                        tail_topic = "topic_" + str(i) + "_" + str(k)
                        topic_relation_triples.append(" ".join([head_topic, relation_list[1], tail_topic]))
                    
                        if head_topic not in topic_list:
                            topic_list.append(head_topic)
                        if tail_topic not in topic_list:
                            topic_list.append(tail_topic)
        
        # Split: mirror of merge — several latter topics moderately similar to one
        # former topic, whose weighted mean embedding is highly similar to it.
        for j in range(former_topic_num):
            split_index = []  # indices of latter-slice topics that are split candidates
            for k in range(latter_topic_num):
                if (cos_sim_matrix[j, k] > t_delta) and (cos_sim_matrix[j, k] <= t_gamma):
                    split_index.append(k)
            
            if len(split_index) >= 2:
                # Support-weighted mean embedding of the candidate latter topics.
                weight_numerator = np.zeros(embedding_dis)
                weight_denominator = 0.0
                for each in split_index:
                    weight_numerator += (latter_topic_support[each] * latter_topic_embeddings[each])
                    weight_denominator += latter_topic_support[each]
                weight_topic_embedding = np.true_divide(weight_numerator, weight_denominator)
                
                # Cosine similarity of the former topic with the weighted embedding.
                split_sim = np.true_divide(np.dot(former_topic_embeddings[j], weight_topic_embedding), 
                                            (former_norm[j] * np.linalg.norm(weight_topic_embedding)))
                
                if (split_sim > t_gamma):
                    # split relation found: one triple per resulting latter topic
                    for each in split_index:
                        head_topic = "topic_" + str(i - 1) + "_" + str(j)
                        tail_topic = "topic_" + str(i) + "_" + str(each)
                        topic_relation_triples.append(" ".join([head_topic, relation_list[2], tail_topic]))
                    
                        if head_topic not in topic_list:
                            topic_list.append(head_topic)
                        if tail_topic not in topic_list:
                            topic_list.append(tail_topic)
        
        # Advance the window: the latter slice becomes the former one.
        former_doc_dis = latter_doc_dis
        former_topic_dis = latter_topics_dis
        former_word_list = latter_word_list
    
        former_topic_support = latter_topic_support
        former_word_embeddings = latter_word_embeddings
        former_topic_embeddings = latter_topic_embeddings
        
        former_topic_num = latter_topic_num
        former_norm = latter_norm
    
    # Build train / validation / test sets.
    # NOTE(review): the three samples are drawn INDEPENDENTLY from the full triple
    # list (80% + 20% + 20%), so they overlap and leak between splits — confirm
    # whether a disjoint shuffle-and-partition was intended.
    train_num = int(len(topic_relation_triples) * 0.8)
    train_triples = random.sample(topic_relation_triples, train_num)
    validate_num = int(len(topic_relation_triples) / 5)
    validate_triples = random.sample(topic_relation_triples, validate_num)
    test_num = int(len(topic_relation_triples) / 5)
    test_triples = random.sample(topic_relation_triples, test_num)
    
    # Write all outputs.
    quick_write_1d_to_text(write_directory + '/topic_list.txt', topic_list)
    quick_write_1d_to_text(write_directory + '/relation_list.txt', relation_list)
    quick_write_1d_to_text(write_directory + '/all_triples.txt', topic_relation_triples)
    quick_write_1d_to_text(write_directory + '/train_triples.txt', train_triples)
    quick_write_1d_to_text(write_directory + '/validate_triples.txt', validate_triples)
    quick_write_1d_to_text(write_directory + '/test_triples.txt', test_triples)


if __name__ == '__main__':
    
    # time.clock() was deprecated in 3.3 and removed in Python 3.8;
    # perf_counter() is the recommended replacement for elapsed-time measurement.
    start = time.perf_counter()
    now_directory = os.getcwd()
    # Data lives in the parent directory of the working directory.
    root_directory = os.path.dirname(now_directory) + '/'
    read_directory1 = root_directory + 'dataset/LDA/feed_topic100'          # document-topic distributions
    read_directory2 = root_directory + 'dataset/LDA/topic_word100'          # topic-word distributions
    read_directory3 = root_directory + 'dataset/text_model/select_words'    # per-slice word lists
    read_filename = root_directory + 'dataset/zhwiki_word_embedding.txt'    # Chinese word embeddings
    
    write_directory = root_directory + 'dataset/trans_e_data'
    
    # makedirs (vs mkdir) also creates missing parent directories.
    if not os.path.exists(write_directory):
        os.makedirs(write_directory)
    
    generate_topic_triples(read_directory1, read_directory2, read_directory3, read_filename, write_directory)
    print('Total time %f seconds' % (time.perf_counter() - start))